{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Printing with CuTe DSL\n",
    "\n",
    "This notebook demonstrates the different ways to print values in CuTe and explains the important distinction between static (compile-time) and dynamic (runtime) values.\n",
    "\n",
    "## Key Concepts\n",
    "- Static values: Known at compile time\n",
    "- Dynamic values: Only known at runtime\n",
    "- Different printing methods for different scenarios\n",
    "- Layout representation in CuTe\n",
    "- Tensor visualization and formatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cutlass\n",
    "import cutlass.cute as cute\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Print Example Function\n",
    "\n",
    "The `print_example` function demonstrates several important concepts:\n",
    "\n",
    "### 1. Python's `print` vs CuTe's `cute.printf`\n",
    "- `print`: Can only show static values at compile time\n",
    "- `cute.printf`: Can display both static and dynamic values at runtime\n",
    "\n",
    "### 2. Value Types\n",
    "- `a`: Dynamic `Int32` value (runtime)\n",
    "- `b`: Static `Constexpr[int]` value (compile-time)\n",
    "\n",
    "### 3. Layout Printing\n",
    "Shows how layouts are represented differently in static vs dynamic contexts:\n",
    "- Static context: Unknown values shown as `?`\n",
    "- Dynamic context: Actual values displayed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.jit\n",
    "def print_example(a: cutlass.Int32, b: cutlass.Constexpr[int]):\n",
    "    \"\"\"\n",
    "    Demonstrates different printing methods in CuTe and how they handle static vs dynamic values.\n",
    "\n",
    "    This example shows:\n",
    "    1. How Python's `print` function works with static values at compile time but can't show dynamic values\n",
    "    2. How `cute.printf` can display both static and dynamic values at runtime\n",
    "    3. The difference between types in static vs dynamic contexts\n",
    "    4. How layouts are represented in both printing methods\n",
    "\n",
    "    Args:\n",
    "        a: A dynamic Int32 value that will be determined at runtime\n",
    "        b: A static (compile-time constant) integer value\n",
    "    \"\"\"\n",
    "    # Use Python `print` to print static information\n",
    "    print(\">>>\", b)  # => 2\n",
    "    # `a` is dynamic value\n",
    "    print(\">>>\", a)  # => ?\n",
    "\n",
    "    # Use `cute.printf` to print dynamic information\n",
    "    cute.printf(\">?? {}\", a)  # => 8\n",
    "    cute.printf(\">?? {}\", b)  # => 2\n",
    "\n",
    "    print(\">>>\", type(a))  # => <class 'cutlass.Int32'>\n",
    "    print(\">>>\", type(b))  # => <class 'int'>\n",
    "\n",
    "    layout = cute.make_layout((a, b))\n",
    "    print(\">>>\", layout)  # => (?,2):(1,?)\n",
    "    cute.printf(\">?? {}\", layout)  # => (8,2):(1,8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile and Run\n",
    "\n",
    "**Direct Compilation and Run**\n",
    "  - `print_example(cutlass.Int32(8), 2)`\n",
    "  - Compiles and runs in one step will execute both static and dynamic print\n",
    "    * `>>>` stands for static print\n",
    "    * `>??` stands for dynamic print"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> 2\n",
      ">>> ?\n",
      ">>> Int32\n",
      ">>> <class 'int'>\n",
      ">>> (?,2):(1,?)\n",
      ">?? 8\n",
      ">?? 2\n",
      ">?? (8,2):(1,8)\n"
     ]
    }
   ],
   "source": [
    "print_example(cutlass.Int32(8), 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile Function\n",
    "\n",
    "When compiles the function with `cute.compile(print_example, cutlass.Int32(8), 2)`, Python interpreter \n",
    "traces code and only evaluate static expression and print static information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">>> 2\n",
      ">>> ?\n",
      ">>> Int32\n",
      ">>> <class 'int'>\n",
      ">>> (?,2):(1,?)\n"
     ]
    }
   ],
   "source": [
    "print_example_compiled = cute.compile(print_example, cutlass.Int32(8), 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Call compiled function\n",
    "\n",
    "Only print out runtime information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ">?? 8\n",
      ">?? 2\n",
      ">?? (8,2):(1,8)\n"
     ]
    }
   ],
   "source": [
    "print_example_compiled(cutlass.Int32(8))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Format String Example\n",
    "\n",
    "The `format_string_example` function shows an important limitation:\n",
    "- F-strings in CuTe are evaluated at compile time\n",
    "- This means dynamic values won't show their runtime values in f-strings\n",
    "- Use `cute.printf` when you need to see runtime values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Direct run output:\n",
      "a: ?, b: 2\n",
      "layout: (?,2):(1,?)\n"
     ]
    }
   ],
   "source": [
    "@cute.jit\n",
    "def format_string_example(a: cutlass.Int32, b: cutlass.Constexpr[int]):\n",
    "    \"\"\"\n",
    "    Format string is evaluated at compile time.\n",
    "    \"\"\"\n",
    "    print(f\"a: {a}, b: {b}\")\n",
    "\n",
    "    layout = cute.make_layout((a, b))\n",
    "    print(f\"layout: {layout}\")\n",
    "\n",
    "\n",
    "print(\"Direct run output:\")\n",
    "format_string_example(cutlass.Int32(8), 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Printing Tensor Examples\n",
    "\n",
    "CuTe provides specialized functionality for printing tensors through the `print_tensor` operation. The `cute.print_tensor` takes the following parameter:\n",
    "- `Tensor` (required): A CuTe tensor object that you want to print. The tensor must support load and store operations\n",
    "- `verbose` (optional, default=False): A boolean flag that controls the level of detail in the output. When set to True, it will print indices details for each element in the tensor.\n",
    "\n",
    "Below example code shows the difference between verbose ON and OFF, and how to print a sub range of the given tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cutlass.cute.runtime import from_dlpack\n",
    "\n",
    "\n",
    "@cute.jit\n",
    "def print_tensor_basic(x: cute.Tensor):\n",
    "    # Print the tensor\n",
    "    print(\"Basic output:\")\n",
    "    cute.print_tensor(x)\n",
    "\n",
    "\n",
    "@cute.jit\n",
    "def print_tensor_verbose(x: cute.Tensor):\n",
    "    # Print the tensor with verbose mode\n",
    "    print(\"Verbose output:\")\n",
    "    cute.print_tensor(x, verbose=True)\n",
    "\n",
    "\n",
    "@cute.jit\n",
    "def print_tensor_slice(x: cute.Tensor, coord: tuple):\n",
    "    # slice a 2D tensor from the 3D tensor\n",
    "    sliced_data = cute.slice_(x, coord)\n",
    "    y = cute.make_rmem_tensor(sliced_data.layout, sliced_data.element_type)\n",
    "    # Convert to TensorSSA format by loading the sliced data into the fragment\n",
    "    y.store(sliced_data.load())\n",
    "    print(\"Slice output:\")\n",
    "    cute.print_tensor(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The default `cute.print_tensor` will output CuTe tensor with datatype, storage space, CuTe layout information, and print data in torch-style format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Basic output:\n",
      "tensor(raw_ptr(0x000000000a5f1d50: f32, generic, align<4>) o (4,3,2):(6,2,1), data=\n",
      "       [[[ 0.000000,  2.000000,  4.000000, ],\n",
      "         [ 6.000000,  8.000000,  10.000000, ],\n",
      "         [ 12.000000,  14.000000,  16.000000, ],\n",
      "         [ 18.000000,  20.000000,  22.000000, ]],\n",
      "\n",
      "        [[ 1.000000,  3.000000,  5.000000, ],\n",
      "         [ 7.000000,  9.000000,  11.000000, ],\n",
      "         [ 13.000000,  15.000000,  17.000000, ],\n",
      "         [ 19.000000,  21.000000,  23.000000, ]]])\n"
     ]
    }
   ],
   "source": [
    "def tensor_print_example1():\n",
    "    shape = (4, 3, 2)\n",
    "\n",
    "    # Creates [0,...,23] and reshape to (4, 3, 2)\n",
    "    data = np.arange(24, dtype=np.float32).reshape(*shape)\n",
    "\n",
    "    print_tensor_basic(from_dlpack(data))\n",
    "\n",
    "\n",
    "tensor_print_example1()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The verbosed print will show coodination details of each element in the tensor. The below example shows how we index element in a 2D 4x3 tensor space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Verbose output:\n",
      "tensor(raw_ptr(0x000000000a814cc0: f32, generic, align<4>) o (4,3):(3,1), data= (\n",
      "\t(0,0)= 0.000000\n",
      "\t(0,1)= 1.000000\n",
      "\t(0,2)= 2.000000\n",
      "\t(1,0)= 3.000000\n",
      "\t(1,1)= 4.000000\n",
      "\t(1,2)= 5.000000\n",
      "\t(2,0)= 6.000000\n",
      "\t(2,1)= 7.000000\n",
      "\t(2,2)= 8.000000\n",
      "\t(3,0)= 9.000000\n",
      "\t(3,1)= 10.000000\n",
      "\t(3,2)= 11.000000\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "def tensor_print_example2():\n",
    "    shape = (4, 3)\n",
    "\n",
    "    # Creates [0,...,11] and reshape to (4, 3)\n",
    "    data = np.arange(12, dtype=np.float32).reshape(*shape)\n",
    "\n",
    "    print_tensor_verbose(from_dlpack(data))\n",
    "\n",
    "\n",
    "tensor_print_example2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To print a subset elements in the given Tensor, we can use cute.slice_ to select a range of the given tensor, load them into register and then print the values with `cute.print_tensor`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Slice output:\n",
      "tensor(raw_ptr(0x00007ffeeae1fc60: f32, rmem, align<32>) o (4):(3), data=\n",
      "       [ 0.000000, ],\n",
      "       [ 3.000000, ],\n",
      "       [Slice output:\n",
      " 6.000000, ],\n",
      "       [ 9.000000, ])\n",
      "tensor(raw_ptr(0x00007ffeeae1fc60: f32, rmem, align<32>) o (3):(1), data=\n",
      "       [ 3.000000, ],\n",
      "       [ 4.000000, ],\n",
      "       [ 5.000000, ])\n"
     ]
    }
   ],
   "source": [
    "def tensor_print_example3():\n",
    "    shape = (4, 3)\n",
    "\n",
    "    # Creates [0,...,11] and reshape to (4, 3)\n",
    "    data = np.arange(12, dtype=np.float32).reshape(*shape)\n",
    "\n",
    "    print_tensor_slice(from_dlpack(data), (None, 0))\n",
    "    print_tensor_slice(from_dlpack(data), (1, None))\n",
    "\n",
    "\n",
    "tensor_print_example3()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To print the tensor in device memory, you can use `cute.print_tensor` within CuTe JIT kernels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "@cute.kernel\n",
    "def print_tensor_gpu(src: cute.Tensor):\n",
    "    print(src)\n",
    "    cute.print_tensor(src)\n",
    "\n",
    "\n",
    "@cute.jit\n",
    "def print_tensor_host(src: cute.Tensor):\n",
    "    print_tensor_gpu(src).launch(grid=(1, 1, 1), block=(1, 1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor<ptr<f32, gmem> o (4,3):(3,1)>\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(raw_ptr(0x00007f5f81200400: f32, gmem, align<4>) o (4,3):(3,1), data=\n",
      "       [[-0.690547, -0.274619, -1.659539, ],\n",
      "        [-1.843524, -1.648711,  1.163431, ],\n",
      "        [-0.716668, -1.900705,  0.592515, ],\n",
      "        [ 0.711333, -0.552422,  0.860237, ]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def tensor_print_example4():\n",
    "    a = torch.randn(4, 3, device=\"cuda\")\n",
    "    cutlass.cuda.initialize_cuda_context()\n",
    "    print_tensor_host(from_dlpack(a))\n",
    "\n",
    "\n",
    "tensor_print_example4()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Currently, `cute.print_tensor` only supports tensor with integer data types and `Float16`/`Float32`/`Float64` floating point data types. We will support more data types in the future."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
